{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "13d5c640",
   "metadata": {},
   "source": [
    "## 深度学习特征\n",
    "\n",
    "提取CT、MRI、内镜、Xray等影像数据的深度学习特征。\n",
    "\n",
    "### Onekey步骤\n",
    "\n",
    "1. 将待提取的数据转化成jpg，可以参考使用OKT-convert2jpg或者OKT-crop_max_roi两个Onekey工具。\n",
    "2. 获取到指定目录的所有图像数据。\n",
    "3. 选择要提取什么样的模型的深度学习特征，目前Onekey支持主流的深度学习模型。（可以考虑使用Onekey进行迁移学习）\n",
    "4. 提取特征，保存特征文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3615a845",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('深度学习特征提取')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c603d40",
   "metadata": {},
   "source": [
    "### 获取2D图像数据\n",
    "\n",
    "1. 对于本身是2D的图像，例如各种内镜以及自然光图像，可以尝试把它转成jpg（非必须操作），需要使用的OKT-convert2jpg工具\n",
    "2. 对于3D的图像，需要根据Mask将其最大的ROI区域crop出来，需要使用的是OKT-crop_max_roi工具。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0344c45d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### OKT-convert2jpg\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('OKT-crop_max_roi')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "754981f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### OKT-convert2jpg\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('OKT-convet2jpg')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b21b469b",
   "metadata": {},
   "source": [
    "## 获取待提取特征的文件\n",
    "\n",
    "提供两种批量处理的模式：\n",
    "1. 目录模式，提取指定目录下的所有jpg文件的特征。\n",
    "2. 文件模式，待提取的数据存储在文件中，每行一个样本。\n",
    "\n",
    "当然也可以在最后自己指定手动提取指定若干文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68d07b97",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### 获取数据\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('深度学习特征提取|获取数据')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8dee942",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "# 目录模式\n",
    "mydir = r'C:\\Users\\onekey\\Desktop\\demo\\crop'\n",
    "# mydir = r'C:\\Users\\onekey\\Project\\OnekeyDS\\CT\\full'\n",
    "directory = os.path.expanduser(mydir)\n",
    "test_samples = [os.path.join(directory, p) for p in os.listdir(directory) if p.endswith('.png') or p.endswith('.jpg')]\n",
    "\n",
    "# 文件模式\n",
    "# test_file = ''\n",
    "# with open(test_file) as f:\n",
    "#     test_samples = [l.strip() for l in f.readlines()]\n",
    "\n",
    "# 自定义模式\n",
    "# test_sampleses = ['path2jpg']\n",
    "test_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26847144",
   "metadata": {},
   "source": [
    "## 确定提取特征\n",
    "\n",
    "通过关键词获取要提取那一层的特征。\n",
    "\n",
    "### 支持的模型名称\n",
    "\n",
    "模型名称替换代码中的 `model_name`变量的值。\n",
    "\n",
    "| **模型系列** | **模型名称**                                                 |\n",
    "| ------------ | ------------------------------------------------------------ |\n",
    "| AlexNet      | alexnet                                                      |\n",
    "| VGG          | vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19 |\n",
    "| ResNet       | resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2 |\n",
    "| DenseNet     | densenet121, densenet169, densenet201, densenet161           |\n",
    "| Inception    | googlenet, inception_v3                                      |\n",
    "| SqueezeNet   | squeezenet1_0, squeezenet1_1                                 |\n",
    "| ShuffleNetV2 | shufflenet_v2_x2_0, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5 |\n",
    "| MobileNet    | mobilenet_v2, mobilenet_v3_large, mobilenet_v3_small         |\n",
    "| MNASNet      | mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3              |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64faba65",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### 获取数据\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('深度学习特征提取|确定模型和特征')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e8d607a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp2 import extract, print_feature_hook, reg_hook_on_module, \\\n",
    "    init_from_model, init_from_onekey\n",
    "\n",
    "model_name = 'googlenet'\n",
    "# model, transformer, device = init_from_model(model_name=model_name)\n",
    "model, transformer, device = init_from_onekey(r'C:\\Users\\onekey\\Desktop\\demo\\src\\models\\dl\\4\\resnet18\\viz')\n",
    "for n, m in model.named_modules():\n",
    "    print('Feature name:', n, \"|| Module:\", m)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87f29370",
   "metadata": {},
   "source": [
    "## 提取特征\n",
    "\n",
    "`Feature name:` 之后的名称为要提取的特征名，例如`layer3.0.conv2`, 一般深度学习特征提取最后一层，例如`avgpool`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec54b1a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### 提取特征\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('深度学习特征提取|提取特征')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "541bfc3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "feature_name = 'avgpool'\n",
    "with open('feature.csv', 'w') as outfile:\n",
    "    hook = partial(print_feature_hook, fp=outfile)\n",
    "    find_num = reg_hook_on_module(feature_name, model, hook)\n",
    "    results = extract(test_samples, model, transformer, device, fp=outfile)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d3cbd6f",
   "metadata": {},
   "source": [
    "## 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9182421a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "#### 特征读取\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('深度学习特征提取|特征读取')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7ac1309",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "features = pd.read_csv('feature.csv', header=None)\n",
    "features.columns=['ID'] + list(features.columns[1:])\n",
    "features.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9634e5f2",
   "metadata": {},
   "source": [
    "### 深度特征压缩\n",
    "\n",
    "深度学习特征压缩，注意压缩到的维度需要小于样本数\n",
    "\n",
    "```python\n",
    "def compress_df_feature(features: pd.DataFrame, dim: int, not_compress: Union[str, List[str]] = None,\n",
    "                        prefix='') -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    压缩深度学习特征\n",
    "    Args:\n",
    "        features: 特征DataFrame\n",
    "        dim: 需要压缩到的维度，此值需要小于样本数\n",
    "        not_compress: 不进行压缩的列。\n",
    "        prefix: 所有特征的前缀。\n",
    "\n",
    "    Returns:\n",
    "\n",
    "    \"\"\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c649a10",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import compress_df_feature\n",
    "\n",
    "cm_features = compress_df_feature(features=features, dim=64, prefix='DL_', not_compress='ID')\n",
    "cm_features.to_csv('compress_features.csv', header=True, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d6c9677",
   "metadata": {},
   "source": [
    "### 迁移学习\n",
    "\n",
    "使用Onekey，提取基于迁移学习的模型特征。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9783b17b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### 特征读取\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('深度学习特征提取|Onekey迁移学习')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1285696",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
